{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TensorFlow 1.0 自定义estimators"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.13.1\n",
      "sys.version_info(major=3, minor=6, micro=4, releaselevel='final', serial=0)\n",
      "matplotlib 2.2.3\n",
      "numpy 1.16.2\n",
      "pandas 0.22.0\n",
      "sklearn 0.19.1\n",
      "tensorflow 1.13.1\n",
      "tensorflow._api.v1.keras 2.2.4-tf\n"
     ]
    }
   ],
   "source": [
    "# 导入\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "\n",
    "print(tf.__version__)\n",
    "print(sys.version_info)\n",
    "for module in mpl,np,pd,sklearn,tf,keras:\n",
    "    print(module.__name__,module.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 读取数据，构建数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   survived     sex   age  n_siblings_spouses  parch     fare  class     deck  \\\n",
      "0         0    male  22.0                   1      0   7.2500  Third  unknown   \n",
      "1         1  female  38.0                   1      0  71.2833  First        C   \n",
      "2         1  female  26.0                   0      0   7.9250  Third  unknown   \n",
      "3         1  female  35.0                   1      0  53.1000  First        C   \n",
      "4         0    male  28.0                   0      0   8.4583  Third  unknown   \n",
      "\n",
      "   embark_town alone  \n",
      "0  Southampton     n  \n",
      "1    Cherbourg     n  \n",
      "2  Southampton     y  \n",
      "3  Southampton     n  \n",
      "4   Queenstown     y  \n",
      "   survived     sex   age  n_siblings_spouses  parch     fare   class  \\\n",
      "0         0    male  35.0                   0      0   8.0500   Third   \n",
      "1         0    male  54.0                   0      0  51.8625   First   \n",
      "2         1  female  58.0                   0      0  26.5500   First   \n",
      "3         1  female  55.0                   0      0  16.0000  Second   \n",
      "4         1    male  34.0                   0      0  13.0000  Second   \n",
      "\n",
      "      deck  embark_town alone  \n",
      "0  unknown  Southampton     y  \n",
      "1        E  Southampton     y  \n",
      "2        C  Southampton     y  \n",
      "3  unknown  Southampton     y  \n",
      "4        D  Southampton     y  \n"
     ]
    }
   ],
   "source": [
    "train_file = 'data/titanic/train.csv'\n",
    "eval_file = 'data/titanic/eval.csv'\n",
    "\n",
    "train_df = pd.read_csv(train_file)\n",
    "eval_df = pd.read_csv(eval_file)\n",
    "\n",
    "print(train_df.head())\n",
    "print(eval_df.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      sex   age  n_siblings_spouses  parch     fare  class     deck  \\\n",
      "0    male  22.0                   1      0   7.2500  Third  unknown   \n",
      "1  female  38.0                   1      0  71.2833  First        C   \n",
      "2  female  26.0                   0      0   7.9250  Third  unknown   \n",
      "3  female  35.0                   1      0  53.1000  First        C   \n",
      "4    male  28.0                   0      0   8.4583  Third  unknown   \n",
      "\n",
      "   embark_town alone  \n",
      "0  Southampton     n  \n",
      "1    Cherbourg     n  \n",
      "2  Southampton     y  \n",
      "3  Southampton     n  \n",
      "4   Queenstown     y  \n",
      "      sex   age  n_siblings_spouses  parch     fare   class     deck  \\\n",
      "0    male  35.0                   0      0   8.0500   Third  unknown   \n",
      "1    male  54.0                   0      0  51.8625   First        E   \n",
      "2  female  58.0                   0      0  26.5500   First        C   \n",
      "3  female  55.0                   0      0  16.0000  Second  unknown   \n",
      "4    male  34.0                   0      0  13.0000  Second        D   \n",
      "\n",
      "   embark_town alone  \n",
      "0  Southampton     y  \n",
      "1  Southampton     y  \n",
      "2  Southampton     y  \n",
      "3  Southampton     y  \n",
      "4  Southampton     y  \n",
      "0    0\n",
      "1    1\n",
      "2    1\n",
      "3    1\n",
      "4    0\n",
      "Name: survived, dtype: int64\n",
      "0    0\n",
      "1    0\n",
      "2    1\n",
      "3    1\n",
      "4    1\n",
      "Name: survived, dtype: int64\n"
     ]
    }
   ],
   "source": [
    "# 将survived字段取出，赋给label\n",
    "y_train = train_df.pop('survived')\n",
    "y_eval = eval_df.pop('survived')\n",
    "\n",
    "print(train_df.head())\n",
    "print(eval_df.head())\n",
    "print(y_train.head())\n",
    "print(y_eval.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sex ['male' 'female']\n",
      "n_siblings_spouses [1 0 3 4 2 5 8]\n",
      "parch [0 1 2 5 3 4]\n",
      "class ['Third' 'First' 'Second']\n",
      "deck ['unknown' 'C' 'G' 'A' 'B' 'D' 'F' 'E']\n",
      "embark_town ['Southampton' 'Cherbourg' 'Queenstown' 'unknown']\n",
      "alone ['n' 'y']\n"
     ]
    }
   ],
   "source": [
    "# 将数据分成两类：离散和连续\n",
    "categorical_columns = ['sex', 'n_siblings_spouses', 'parch', 'class',\n",
    "                       'deck', 'embark_town', 'alone']\n",
    "numeric_columns = ['age', 'fare']\n",
    "\n",
    "feature_columns = []\n",
    "for categorical_column in categorical_columns:\n",
    "    # 获取离散值所有可能值\n",
    "    vocab = train_df[categorical_column].unique()\n",
    "    print(categorical_column, vocab)\n",
    "    # 对离散值封装，做onehot编码，再加入feature_columns中\n",
    "    feature_columns.append(\n",
    "        tf.feature_column.indicator_column(\n",
    "            tf.feature_column.categorical_column_with_vocabulary_list(\n",
    "                categorical_column, vocab)))\n",
    "    \n",
    "for numeric_column in numeric_columns:\n",
    "    feature_columns.append(\n",
    "        tf.feature_column.numeric_column(numeric_column, dtype=tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 从pandas dataframe中构建dataset\n",
    "def make_dataset(data_df, label_df, epochs = 10, shuffle = True, batch_size = 32):\n",
    "    dataset = tf.data.Dataset.from_tensor_slices((dict(data_df), label_df))\n",
    "    if shuffle:\n",
    "        dataset = dataset.shuffle(10000)\n",
    "    dataset = dataset.repeat(epochs).batch(batch_size)\n",
    "    return dataset.make_one_shot_iterator().get_next()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 自定义estimator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dir = 'customized_estimator'\n",
    "if not os.path.exists(output_dir):\n",
    "    os.mkdir(output_dir)\n",
    "    \n",
    "# 定义estimator所需要的函数\n",
    "# mode:model runtime status: Train,Eval,Predict\n",
    "def model_fn(features, labels, mode, params):\n",
    "    # 定义网络\n",
    "    input_for_next_layer = tf.feature_column.input_layer(\n",
    "        features, params['feature_columns'])\n",
    "    for n_unit in params['hidden_units']:\n",
    "        input_for_next_layer = tf.layers.dense(\n",
    "            input_for_next_layer, units=n_unit, activation=tf.nn.relu)\n",
    "    logits = tf.layers.dense(\n",
    "        input_for_next_layer, params['n_classes'], activation=None)\n",
    "    predicted_classes = tf.argmax(logits, 1)\n",
    "    \n",
    "    # 根据不同的mode，返回不同的值\n",
    "    if mode == tf.estimator.ModeKeys.PREDICT:\n",
    "        # 定义预测的返回值\n",
    "        predictions = {\n",
    "            'class_ids': predicted_classes[:, tf.newaxis],\n",
    "            'probabilities': tf.nn.softmax(logits),\n",
    "            'logits': logits\n",
    "        }\n",
    "        return tf.estimator.EstimatorSpec(mode, predictions=predictions)\n",
    "    \n",
    "    # 计算指标\n",
    "    loss = tf.losses.sparse_softmax_cross_entropy(labels = labels, logits = logits)\n",
    "    accuracy = tf.metrics.accuracy(\n",
    "        labels = labels, predictions = predicted_classes, name = 'acc_op')\n",
    "    metrics = {'accuracy': accuracy}\n",
    "    \n",
    "    if mode == tf.estimator.ModeKeys.EVAL:\n",
    "        return tf.estimator.EstimatorSpec(mode, loss = loss, eval_metric_ops = metrics)\n",
    "\n",
    "    optimizer = tf.train.AdamOptimizer()\n",
    "    train_op = optimizer.minimize(\n",
    "        loss, global_step = tf.train.get_global_step())\n",
    "    \n",
    "    if mode == tf.estimator.ModeKeys.TRAIN:\n",
    "        return tf.estimator.EstimatorSpec(mode, loss = loss, train_op = train_op)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using default config.\n",
      "INFO:tensorflow:Using config: {'_model_dir': 'customized_estimator', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n",
      "graph_options {\n",
      "  rewrite_options {\n",
      "    meta_optimizer_iterations: ONE\n",
      "  }\n",
      "}\n",
      ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fdb36af7898>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n",
      "WARNING:tensorflow:From /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Colocations handled automatically by placer.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "WARNING:tensorflow:From /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column.py:205: NumericColumn._get_dense_tensor (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed after 2018-11-30.\n",
      "Instructions for updating:\n",
      "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n",
      "WARNING:tensorflow:From /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column.py:2121: NumericColumn._transform_feature (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed after 2018-11-30.\n",
      "Instructions for updating:\n",
      "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n",
      "WARNING:tensorflow:From /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column_v2.py:2703: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.cast instead.\n",
      "WARNING:tensorflow:From /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column.py:206: NumericColumn._variable_shape (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed after 2018-11-30.\n",
      "Instructions for updating:\n",
      "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n",
      "WARNING:tensorflow:From /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column.py:205: IndicatorColumn._get_dense_tensor (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed after 2018-11-30.\n",
      "Instructions for updating:\n",
      "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n",
      "WARNING:tensorflow:From /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column.py:2121: IndicatorColumn._transform_feature (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed after 2018-11-30.\n",
      "Instructions for updating:\n",
      "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n",
      "WARNING:tensorflow:From /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column_v2.py:4295: VocabularyListCategoricalColumn._get_sparse_tensors (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed after 2018-11-30.\n",
      "Instructions for updating:\n",
      "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n",
      "WARNING:tensorflow:From /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column.py:2121: VocabularyListCategoricalColumn._transform_feature (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed after 2018-11-30.\n",
      "Instructions for updating:\n",
      "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n",
      "WARNING:tensorflow:From /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/ops/lookup_ops.py:1137: to_int64 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.cast instead.\n",
      "WARNING:tensorflow:From /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column_v2.py:4266: IndicatorColumn._variable_shape (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed after 2018-11-30.\n",
      "Instructions for updating:\n",
      "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n",
      "WARNING:tensorflow:From /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column_v2.py:4321: VocabularyListCategoricalColumn._num_buckets (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed after 2018-11-30.\n",
      "Instructions for updating:\n",
      "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n",
      "WARNING:tensorflow:From <ipython-input-7-0c6ccbe3f8b6>:13: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.dense instead.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 0 into customized_estimator/model.ckpt.\n",
      "INFO:tensorflow:loss = 1.3071299, step = 0\n",
      "INFO:tensorflow:global_step/sec: 245.894\n",
      "INFO:tensorflow:loss = 0.48425955, step = 100 (0.408 sec)\n",
      "INFO:tensorflow:global_step/sec: 329.592\n",
      "INFO:tensorflow:loss = 0.55831015, step = 200 (0.303 sec)\n",
      "INFO:tensorflow:global_step/sec: 331.119\n",
      "INFO:tensorflow:loss = 0.5708145, step = 300 (0.304 sec)\n",
      "INFO:tensorflow:global_step/sec: 338.001\n",
      "INFO:tensorflow:loss = 0.37386686, step = 400 (0.294 sec)\n",
      "INFO:tensorflow:global_step/sec: 328.633\n",
      "INFO:tensorflow:loss = 0.31370935, step = 500 (0.304 sec)\n",
      "INFO:tensorflow:global_step/sec: 337.292\n",
      "INFO:tensorflow:loss = 0.24561787, step = 600 (0.296 sec)\n",
      "INFO:tensorflow:global_step/sec: 332.143\n",
      "INFO:tensorflow:loss = 0.39784604, step = 700 (0.301 sec)\n",
      "INFO:tensorflow:global_step/sec: 328.185\n",
      "INFO:tensorflow:loss = 0.5475549, step = 800 (0.306 sec)\n",
      "INFO:tensorflow:global_step/sec: 334.18\n",
      "INFO:tensorflow:loss = 0.49038672, step = 900 (0.298 sec)\n",
      "INFO:tensorflow:global_step/sec: 312.21\n",
      "INFO:tensorflow:loss = 0.24679664, step = 1000 (0.320 sec)\n",
      "INFO:tensorflow:global_step/sec: 323.758\n",
      "INFO:tensorflow:loss = 0.48531774, step = 1100 (0.309 sec)\n",
      "INFO:tensorflow:global_step/sec: 323.186\n",
      "INFO:tensorflow:loss = 0.35774, step = 1200 (0.311 sec)\n",
      "INFO:tensorflow:global_step/sec: 330.6\n",
      "INFO:tensorflow:loss = 0.31498316, step = 1300 (0.301 sec)\n",
      "INFO:tensorflow:global_step/sec: 327.487\n",
      "INFO:tensorflow:loss = 0.38564783, step = 1400 (0.306 sec)\n",
      "INFO:tensorflow:global_step/sec: 313.389\n",
      "INFO:tensorflow:loss = 0.3783471, step = 1500 (0.319 sec)\n",
      "INFO:tensorflow:global_step/sec: 330.789\n",
      "INFO:tensorflow:loss = 0.208392, step = 1600 (0.302 sec)\n",
      "INFO:tensorflow:global_step/sec: 336.623\n",
      "INFO:tensorflow:loss = 0.3715057, step = 1700 (0.299 sec)\n",
      "INFO:tensorflow:global_step/sec: 339.639\n",
      "INFO:tensorflow:loss = 0.49013346, step = 1800 (0.293 sec)\n",
      "INFO:tensorflow:global_step/sec: 343.029\n",
      "INFO:tensorflow:loss = 0.48774076, step = 1900 (0.291 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 1960 into customized_estimator/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 0.23161016.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow_estimator.python.estimator.estimator.Estimator at 0x7fdb36af76d8>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 使用定义的函数自定义estimator\n",
    "estimator = tf.estimator.Estimator(\n",
    "    model_fn=model_fn,\n",
    "    model_dir=output_dir,\n",
    "    params={\n",
    "        'feature_columns': feature_columns,\n",
    "        'hidden_units': [100, 100],\n",
    "        'n_classes': 2\n",
    "    })\n",
    "\n",
    "estimator.train(input_fn = lambda : make_dataset(train_df, y_train, epochs = 100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Starting evaluation at 2020-02-07T15:34:47Z\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "WARNING:tensorflow:From /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use standard file APIs to check for files with this prefix.\n",
      "INFO:tensorflow:Restoring parameters from customized_estimator/model.ckpt-1960\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Finished evaluation at 2020-02-07-15:34:47\n",
      "INFO:tensorflow:Saving dict for global step 1960: accuracy = 0.780303, global_step = 1960, loss = 0.5266627\n",
      "INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1960: customized_estimator/model.ckpt-1960\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'accuracy': 0.780303, 'global_step': 1960, 'loss': 0.5266627}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "estimator.evaluate(lambda : make_dataset(eval_df, y_eval, epochs = 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
